Density Ratios¶

In [1]:
import numpy as np
import seaborn as sns
#Generate Random Samples from a Mixture of two normal distributions
# Define the parameters for the two normal distributions
mean1 = 0
variance1 = 1

mean2 = 5
variance2 = 2

# Generate random draws from the mixture distribution
n = 1000
weights = [0.95, 0.05]  # Equal weights for the two distributions

np.random.seed(42)

# Generate random indices to select which distribution to sample from
indices = np.random.choice([0, 1], size=n, p=weights)

# Generate random samples from the mixture distribution
p = np.zeros(n)
for i in range(n):
    if indices[i] == 0:
        p[i] = np.random.normal(mean1, np.sqrt(variance1))
    else:
        p[i] = np.random.normal(mean2, np.sqrt(variance2))

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Plot histogram of samples
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')

# Plot rug plot of samples
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')

plt.tight_layout()
plt.show()
No description has been provided for this image
In [2]:
import numpy as np

def generate_mixture_samples(means, variances, weights, n):
    # Generate random indices to select which distribution to sample from
    indices = np.random.choice(len(means), size=n, p=weights)
    
    # Generate random samples from the mixture distribution
    samples = np.zeros(n)
    for i in range(n):
        samples[i] = np.random.normal(means[indices[i]], np.sqrt(variances[indices[i]]))
    
    return samples
In [4]:
import numpy as np
from scipy.stats import norm

def compute_density_ratio(candidate_means, candidate_variances, candidate_weights, true_means, true_variances, true_weights):
    # Define the grid of values
    grid = np.linspace(-5, 5, 1000)
    
    # Compute the density ratio for each value in the grid
    density_ratio = np.zeros_like(grid)
    for i, value in enumerate(grid):
        candidate_pdf = np.sum([weight * norm.pdf(value, mean, np.sqrt(variance)) for mean, variance, weight in zip(candidate_means, candidate_variances, candidate_weights)])
        true_pdf = np.sum([weight * norm.pdf(value, mean, np.sqrt(variance)) for mean, variance, weight in zip(true_means, true_variances, true_weights)])
        density_ratio[i] = candidate_pdf / true_pdf
    
    return density_ratio
In [5]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([1],[1],[1],1000)
dr = compute_density_ratio([1],[1],[1],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))

# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))

# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')

# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)

# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))

# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)

import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))

# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()

# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()

# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))

# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')




plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`.
  warnings.warn(
No description has been provided for this image
In [6]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([0],[1],[1],1000)
dr = compute_density_ratio([0],[1],[1],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))

# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))

# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')

# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)

# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))

# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)

import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))

# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()

# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()

# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))

# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')




plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`.
  warnings.warn(
No description has been provided for this image
In [ ]:
means = np.linspace(-5, 5, 25)
variances = np.linspace(0.001, 3, 25)
import itertools
# Create a cartesian product of means and variances
grid = list(itertools.product(means, variances))

losses = np.zeros((len(means), len(variances)))

for i, (mean, variance) in enumerate(grid):
    q = generate_mixture_samples([mean], [variance], [1], 1000)
    #dr = compute_density_ratio([mean], [variance], [1], [0, 5], [1, 2], [.95, .05])
    x = np.concatenate((p, q))
    y = np.concatenate((np.ones_like(p), np.zeros_like(q)))
    logreg = LogisticRegression(penalty='none')
    logreg.fit(x.reshape(-1, 1), y)
    y_pred = logreg.predict_proba(x.reshape(-1, 1))
    loss = -log_loss(y, y_pred)
    losses[i // len(variances), i % len(variances)] = loss
In [10]:
# Find the point with the minimum loss
min_idx = np.argmin(losses)
min_row, min_col = np.unravel_index(min_idx, losses.shape)
best_mean = means[min_row]
best_variance = variances[min_col]

# Create a meshgrid for contour plotting
X, Y = np.meshgrid(variances, means)

plt.figure(figsize=(8, 6))
contour = plt.contourf(X, Y, losses, levels=20, cmap='viridis')
plt.colorbar(contour)
plt.plot(best_variance, best_mean, 'bX', markersize=15, markeredgewidth=3, label='Min Loss')
plt.xlabel('Variance')
plt.ylabel('Mean')
plt.title('Contour Plot of Losses')
plt.legend()
plt.show()
No description has been provided for this image
In [11]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([0],[2.5],[1],1000)
dr = compute_density_ratio([0],[1],[2.5],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))

# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))

# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')

# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)

# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))

# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)

import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))

# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()

# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()

# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))

# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')




plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`.
  warnings.warn(
No description has been provided for this image
In [12]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([0,5],[1,2],[.95,.05],1000)
dr = compute_density_ratio([0,5],[1,2],[.95,.05],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))

# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))

# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')

# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)

# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))

# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)

import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))

# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()

# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()

# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))

# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')




plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`.
  warnings.warn(
No description has been provided for this image

Simple f-GAN¶

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def sample_real_data(batch_size):
    """
    Generate samples from a mixture of 9 (3x3 grid) 2D Gaussians.
    Each Gaussian component has a center that is uniformly chosen from:
      (-1,-1), (-1,0), (-1,1),
       (0,-1),  (0,0),  (0,1),
       (1,-1),  (1,0),  (1,1).
    Each component has a standard deviation.
    """
    centers = np.array([[-1, -1], [-1, 0], [-1, 1],
                        [ 0, -1], [ 0, 0], [ 0, 1],
                        [ 1, -1], [ 1, 0], [ 1, 1]])
    num_components = centers.shape[0]
    indices = np.random.choice(num_components, size=batch_size)
    chosen_centers = centers[indices]
    std = 0.1
    samples = chosen_centers + np.random.randn(batch_size, 2) * std
    return torch.tensor(samples, dtype=torch.float32)
In [2]:
# Sample 10,000 points from the real data distribution.
samples = sample_real_data(10000)
x = samples[:, 0]
y = samples[:, 1]

# Set up the plot.
plt.figure(figsize=(8, 8))
# Use seaborn's kdeplot to create a density plot.
sns.kdeplot(x=x, y=y, cmap="viridis", fill=True, thresh=0, levels=100)
plt.title("Density of Real Data (10,000 Samples)")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
No description has been provided for this image
In [3]:
# -----------------------
# 1. Define a generic MLP (like the JAX MLP class)
# -----------------------
class MLP(nn.Module):
    def __init__(self, input_dim, features):
        """
        Constructs an MLP that applies a linear layer followed by ReLU for all but the final layer.
        Args:
            input_dim: Dimension of the input.
            features: List of integers, where each element corresponds to the output size of a dense layer.
                      The final element does not get a nonlinearity.
        """
        super(MLP, self).__init__()
        layers = []
        in_dim = input_dim
        for i, out_dim in enumerate(features):
            layers.append(nn.Linear(in_dim, out_dim))
            if i < len(features) - 1:
                layers.append(nn.ReLU(inplace=True))
            in_dim = out_dim
        self.net = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.net(x)
In [4]:
# -----------------------
# 3. Define loss functions using log-sigmoid (replicates the JAX loss computation)
# -----------------------
def log_sigmoid(x):
    # Returns log(sigmoid(x))
    return torch.log(torch.sigmoid(x))

def discriminator_loss(D, G, real_examples, latents):
    """
    Computes the discriminator loss as:
        loss = mean( -log_sigmoid(D(real)) - log_sigmoid(-D(G(latents))) )
    """
    real_logits = D(real_examples)
    fake_examples = G(latents)
    fake_logits = D(fake_examples)
    loss_real = - log_sigmoid(real_logits)
    loss_fake = - log_sigmoid(-fake_logits)
    return torch.mean(loss_real + loss_fake)

def generator_loss(D, G, latents):
    """
    Computes the generator loss as:
        loss = mean( - log_sigmoid(D(G(latents))) )
    """
    fake_examples = G(latents)
    fake_logits = D(fake_examples)
    loss = - log_sigmoid(fake_logits)
    return torch.mean(loss)
In [5]:
# -----------------------
# 4. Training loop replicating the JAX training steps with SGD
# -----------------------
def train_gan(num_iters=20001, batch_size=512, latent_size=32, lr=0.05, n_save=2000, draw_contours=False, device='cpu'):
    device = torch.device(device)
    
    # Create the discriminator and generator.
    # Discriminator: input_dim 2, hidden layers: 25, 25, output_dim: 1
    # Generator: input_dim latent_size, hidden layers: 25, 25, output_dim: 2
    D = MLP(input_dim=2, features=[128, 128, 128, 1]).to(device)
    G = MLP(input_dim=latent_size, features=[128,128,128,2]).to(device)
    
    # Set up SGD optimizers replicating the JAX SGD with lr=0.05.
    optimizer_D = optim.SGD(D.parameters(), lr=lr)
    optimizer_G = optim.SGD(G.parameters(), lr=lr)
    
    # Prepare a fixed test latent vector for evaluation (10,000 samples)
    test_latents = torch.randn(10000, latent_size, device=device)
    
    history = []  # will store tuples: (iteration, fake_examples, disc_contour, disc_loss, gen_loss)
    
    for i in range(num_iters):
        # Sample minibatch of real examples (shape: [batch_size, 2])
        real_examples = sample_real_data(batch_size).to(device)
        # Sample minibatch of latent vectors from a standard normal (shape: [batch_size, latent_size])
        latents = torch.randn(batch_size, latent_size, device=device)
        
        # -- Discriminator step --
        optimizer_D.zero_grad()
        loss_D = discriminator_loss(D, G, real_examples, latents)
        loss_D.backward()
        optimizer_D.step()
        
        # -- Generator step --
        optimizer_G.zero_grad()
        # We use the same minibatch of latents here.
        loss_G = generator_loss(D, G, latents)
        loss_G.backward()
        optimizer_G.step()
        
        if i % n_save == 0:
            print(f"i = {i}, Discriminator Loss = {loss_D.item()}, Generator Loss = {loss_G.item()}")
            with torch.no_grad():
                fake_examples = G(test_latents)
            disc_contour = None
            if draw_contours:
                # Optional: compute a contour measure over some grid if desired.
                # (The original code computes: -D(pairs) + log_sigmoid(D(pairs)))
                # For simplicity, we leave this as None.
                disc_contour = None
            history.append((i, fake_examples.cpu(), disc_contour, loss_D.item(), loss_G.item()))
    
    return D, G, history
In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G, history = train_gan(num_iters=20001, batch_size=512, latent_size=32, lr=0.05,
                            n_save=2000, draw_contours=False, device=device)
i = 0, Discriminator Loss = 1.384421944618225, Generator Loss = 0.654618501663208
i = 2000, Discriminator Loss = 1.2403438091278076, Generator Loss = 0.6935384273529053
i = 4000, Discriminator Loss = 1.0644993782043457, Generator Loss = 1.0018454790115356
i = 6000, Discriminator Loss = 1.0645021200180054, Generator Loss = 1.893606424331665
i = 8000, Discriminator Loss = 0.7953554391860962, Generator Loss = 1.5759837627410889
i = 10000, Discriminator Loss = 0.8092895746231079, Generator Loss = 1.2489733695983887
i = 12000, Discriminator Loss = 0.8678987622261047, Generator Loss = 1.3148552179336548
i = 14000, Discriminator Loss = 0.9495569467544556, Generator Loss = 1.1578044891357422
i = 16000, Discriminator Loss = 0.8955338001251221, Generator Loss = 1.198028564453125
i = 18000, Discriminator Loss = 0.8887498378753662, Generator Loss = 1.2266111373901367
i = 20000, Discriminator Loss = 0.9115414619445801, Generator Loss = 1.2572792768478394
In [7]:
import matplotlib.pyplot as plt
import seaborn as sns

# Assuming 'history' is available from training
# Each element in history is a tuple: (iteration, fake_samples, disc_loss, gen_loss)
for entry in history:
    iteration, fake_samples, disc_contour, disc_loss, gen_loss = entry
    
    # Create a figure for each snapshot
    plt.figure(figsize=(6, 6))
    
    # Use Seaborn's kdeplot to compute and display the 2D kernel density estimate.
    sns.kdeplot(x=fake_samples[:, 0], y=fake_samples[:, 1],
                fill=True, levels=50, cmap="viridis")
    
    # Add labels and a title with iteration and losses information.
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"Estimated Density at Iteration {iteration}\n"
              f"Discriminator Loss: {disc_loss:.4f}, Generator Loss: {gen_loss:.4f}")
    plt.tight_layout()
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Simple w-GAN¶

In [20]:
class MLP(nn.Module):
    def __init__(self, input_dim, features):
        """
        Constructs an MLP with hidden layers specified by the list `features`.
        A ReLU activation is applied after each layer except the final one.
        """
        super(MLP, self).__init__()
        layers = []
        in_dim = input_dim
        for i, out_dim in enumerate(features):
            layers.append(nn.Linear(in_dim, out_dim))
            if i < len(features) - 1:
                layers.append(nn.ReLU(inplace=True))
            in_dim = out_dim
        self.net = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.net(x)

# -----------------------
# 3. Define Wasserstein Losses
# -----------------------
def critic_loss(D, G, real_examples, latents):
    """
    Computes the Wasserstein critic loss:
      L_D = -(E[D(real)] - E[D(fake)])
    """
    real_scores = D(real_examples)
    fake_scores = D(G(latents))
    return fake_scores.mean() - real_scores.mean()

def generator_loss(D, G, latents):
    """
    Computes the generator loss for WGAN:
      L_G = -E[D(fake)]
    """
    return - D(G(latents)).mean()
In [21]:
# -----------------------
# 4. Training Loop for WGAN with n_disc updates per iteration
# -----------------------
def train_wgan(num_iters=20001, batch_size=512, latent_size=32, lr=0.05,
               n_save=2000, n_disc=5, clip_value=0.01, device='cpu', draw_contours = False):
    device = torch.device(device)
    
    # Instantiate the critic and generator.
    D = MLP(input_dim=2, features=[128,128,128, 1]).to(device)  # Critic: 2D -> score (no sigmoid!)
    G = MLP(input_dim=latent_size, features=[128,128,128, 2]).to(device)  # Generator: latent -> 2D output
    
    # Set up optimizers using SGD as in the original JAX code.
    optimizer_D = optim.SGD(D.parameters(), lr=lr)
    optimizer_G = optim.SGD(G.parameters(), lr=lr)
    
    # Fixed test latents for monitoring (10,000 samples)
    test_latents = torch.randn(10000, latent_size, device=device)
    
    history = []  # List to store snapshots: (iteration, fake_samples, critic_loss, generator_loss)
    
    for i in range(num_iters):
        # --- Critic (Discriminator) update: n_disc iterations ---
        for _ in range(n_disc):
            real_examples = sample_real_data(batch_size).to(device)
            latents = torch.randn(batch_size, latent_size, device=device)
            
            optimizer_D.zero_grad()
            loss_D = critic_loss(D, G, real_examples, latents)
            loss_D.backward()
            optimizer_D.step()
            
            # Weight clipping to enforce Lipschitz condition.
            for p in D.parameters():
                p.data.clamp_(-clip_value, clip_value)
        
        # --- Generator update (one step after n_disc updates) ---
        latents = torch.randn(batch_size, latent_size, device=device)
        optimizer_G.zero_grad()
        loss_G = generator_loss(D, G, latents)
        loss_G.backward()
        optimizer_G.step()
        
        
        if i % n_save == 0:
            print(f"i = {i}, Discriminator Loss = {loss_D.item()}, Generator Loss = {loss_G.item()}")
            with torch.no_grad():
                fake_examples = G(test_latents)
            disc_contour = None
            if draw_contours:
                # Optional: compute a contour measure over some grid if desired.
                # (The original code computes: -D(pairs) + log_sigmoid(D(pairs)))
                # For simplicity, we leave this as None.
                disc_contour = None
            history.append((i, fake_examples.cpu(), disc_contour, loss_D.item(), loss_G.item()))
    
    return D, G, history
In [22]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G, history = train_wgan(num_iters=20001, batch_size=10000, latent_size=32, lr=0.05,
                            n_save=2000, n_disc=2, clip_value=.01, device=device)
i = 0, Discriminator Loss = -4.794273991137743e-07, Generator Loss = 0.00012903052265755832
i = 2000, Discriminator Loss = -1.542569589219056e-05, Generator Loss = 4.844953946303576e-05
i = 4000, Discriminator Loss = -4.3357867980375886e-05, Generator Loss = -3.837565236608498e-05
i = 6000, Discriminator Loss = -8.668069494888186e-05, Generator Loss = -0.0002593372482806444
i = 8000, Discriminator Loss = -0.00016360956942662597, Generator Loss = -0.0002497853129170835
i = 10000, Discriminator Loss = -0.0002925149165093899, Generator Loss = -0.0002160591830033809
i = 12000, Discriminator Loss = -0.0005131656071171165, Generator Loss = -0.0001401335612172261
i = 14000, Discriminator Loss = -0.0008085581357590854, Generator Loss = -7.504215318476781e-05
i = 16000, Discriminator Loss = -0.0011433582985773683, Generator Loss = -9.464036338613369e-06
i = 18000, Discriminator Loss = -0.0014178385026752949, Generator Loss = 2.4077523903542897e-06
i = 20000, Discriminator Loss = -0.0015906720655038953, Generator Loss = -1.5236424587783404e-05
In [23]:
import matplotlib.pyplot as plt
import seaborn as sns

# Assuming 'history' is available from training
# Each element in history is a tuple: (iteration, fake_samples, disc_loss, gen_loss)
for entry in history:
    iteration, fake_samples, disc_contour, disc_loss, gen_loss = entry
    
    # Create a figure for each snapshot
    plt.figure(figsize=(6, 6))
    
    # Use Seaborn's kdeplot to compute and display the 2D kernel density estimate.
    sns.kdeplot(x=fake_samples[:, 0], y=fake_samples[:, 1],
                fill=True, levels=50, cmap="viridis")
    
    # Add labels and a title with iteration and losses information.
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"Estimated Density at Iteration {iteration}\n"
              f"Discriminator Loss: {disc_loss:.4f}, Generator Loss: {gen_loss:.4f}")
    plt.tight_layout()
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

WGAN-GP¶

In [24]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------
# 1. Define a generic MLP (used for both critic and generator)
# -----------------------
class MLP(nn.Module):
    def __init__(self, input_dim, features):
        """
        Constructs an MLP with hidden layers specified by the list `features`.
        A ReLU activation is applied after each layer except the final one.
        """
        super(MLP, self).__init__()
        layers = []
        in_dim = input_dim
        for i, out_dim in enumerate(features):
            layers.append(nn.Linear(in_dim, out_dim))
            if i < len(features) - 1:
                layers.append(nn.ReLU(inplace=True))
            in_dim = out_dim
        self.net = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.net(x)

# -----------------------
# 2. Real Data Distribution: 9-component Mixture of 2D Gaussians
# -----------------------
def sample_real_data(batch_size, std=0.1):
    """
    Samples a batch of 2D points drawn from one of 9 Gaussian components arranged on a 3x3 grid.
    The centers are:
       (-1,-1), (-1, 0), (-1, 1),
       ( 0,-1), ( 0, 0), ( 0, 1),
       ( 1,-1), ( 1, 0), ( 1, 1)
    """
    centers = np.array([[-1, -1], [-1, 0], [-1, 1],
                        [ 0, -1], [ 0, 0], [ 0, 1],
                        [ 1, -1], [ 1, 0], [ 1, 1]])
    num_components = centers.shape[0]
    indices = np.random.choice(num_components, size=batch_size)
    chosen_centers = centers[indices]
    samples = chosen_centers + np.random.randn(batch_size, 2) * std
    return torch.tensor(samples, dtype=torch.float)

# -----------------------
# 3. Define Gradient Penalty and Wasserstein Losses (WGAN-GP)
# -----------------------
def gradient_penalty(D, real_samples, fake_samples, device):
    """Computes the gradient penalty for interpolated samples."""
    batch_size = real_samples.size(0)
    # Sample interpolation coefficient uniformly between 0 and 1.
    alpha = torch.rand(batch_size, 1, device=device)
    alpha = alpha.expand_as(real_samples)
    # Create interpolated samples.
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates.requires_grad_(True)
    
    # Compute critic scores on interpolated samples.
    d_interpolates = D(interpolates)
    
    # For each sample, create a tensor of ones with the same shape as the output.
    ones = torch.ones(d_interpolates.size(), device=device)
    # Compute gradients of critic scores with respect to the interpolated samples.
    gradients = torch.autograd.grad(outputs=d_interpolates, inputs=interpolates,
                                    grad_outputs=ones,
                                    create_graph=True,
                                    retain_graph=True,
                                    only_inputs=True)[0]
    # Flatten gradients per sample.
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    # Compute penalty as (gradient norm - 1)^2.
    gp = ((gradient_norm - 1) ** 2).mean()
    return gp

def critic_loss(D, G, real_examples, latents, lambda_gp, device):
    """
    Computes the WGAN critic loss with gradient penalty:
      L_D = E[D(fake)] - E[D(real)] + lambda_gp * GP,
    where GP is the gradient penalty.
    """
    fake_examples = G(latents)
    real_scores = D(real_examples)
    fake_scores = D(fake_examples)
    loss = fake_scores.mean() - real_scores.mean()
    
    # Compute gradient penalty on interpolated samples.
    gp = gradient_penalty(D, real_examples, fake_examples, device)
    loss += lambda_gp * gp
    return loss

def generator_loss(D, G, latents):
    """
    Computes the generator loss for WGAN:
      L_G = -E[D(G(latents))]
    """
    return - D(G(latents)).mean()

# -----------------------
# 4. Training Loop for WGAN-GP with n_disc updates per iteration
# -----------------------
def train_wgan_gp(num_iters=20001, batch_size=512, latent_size=32, lr=0.05,
                  n_save=2000, n_disc=5, lambda_gp=10, device='cpu', draw_contours = False):
    device = torch.device(device)
    
    # Instantiate the critic and generator.
    D = MLP(input_dim=2, features=[128,128,128,1]).to(device)  # Critic: 2D -> score (no activation)
    G = MLP(input_dim=latent_size, features=[128, 128,128, 2]).to(device)  # Generator: latent -> 2D output
    
    # Set up optimizers (using SGD as in the JAX code).
    optimizer_D = optim.SGD(D.parameters(), lr=lr)
    optimizer_G = optim.SGD(G.parameters(), lr=lr)
    
    # Fixed test latent vectors for monitoring (10,000 samples).
    test_latents = torch.randn(10000, latent_size, device=device)
    
    history = []  # To store snapshots: (iteration, fake_samples, critic_loss, generator_loss)
    
    for i in range(num_iters):
        # --- Critic (Discriminator) update: perform n_disc updates ---
        for _ in range(n_disc):
            real_examples = sample_real_data(batch_size).to(device)
            latents = torch.randn(batch_size, latent_size, device=device)
            
            optimizer_D.zero_grad()
            loss_D = critic_loss(D, G, real_examples, latents, lambda_gp, device)
            loss_D.backward()
            optimizer_D.step()
        
        # --- Generator update (one update after n_disc critic updates) ---
        latents = torch.randn(batch_size, latent_size, device=device)
        optimizer_G.zero_grad()
        loss_G = generator_loss(D, G, latents)
        loss_G.backward()
        optimizer_G.step()
        
        if i % n_save == 0:
            print(f"i = {i}, Discriminator Loss = {loss_D.item()}, Generator Loss = {loss_G.item()}")
            with torch.no_grad():
                fake_examples = G(test_latents)
            disc_contour = None
            if draw_contours:
                # Optional: compute a contour measure over some grid if desired.
                # (The original code computes: -D(pairs) + log_sigmoid(D(pairs)))
                # For simplicity, we leave this as None.
                disc_contour = None
            history.append((i, fake_examples.cpu(), disc_contour, loss_D.item(), loss_G.item()))
    
    return D, G, history
In [29]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G, history = train_wgan_gp(num_iters=20001, batch_size=10000, latent_size=32, lr=0.05,
                                n_save=2000, n_disc=2, lambda_gp=1, device=device)
i = 0, Discriminator Loss = 0.9039061069488525, Generator Loss = 0.1036360114812851
i = 2000, Discriminator Loss = -0.1359843909740448, Generator Loss = 5.910412788391113
i = 4000, Discriminator Loss = -0.07140640914440155, Generator Loss = 4.818828582763672
i = 6000, Discriminator Loss = -0.02494090050458908, Generator Loss = 4.749511241912842
i = 8000, Discriminator Loss = -0.004812396131455898, Generator Loss = 4.237672328948975
i = 10000, Discriminator Loss = -0.0016526570543646812, Generator Loss = 3.8261005878448486
i = 12000, Discriminator Loss = -0.002830632496625185, Generator Loss = 3.5270280838012695
i = 14000, Discriminator Loss = 0.014997678808867931, Generator Loss = 3.673243761062622
i = 16000, Discriminator Loss = 0.0187930129468441, Generator Loss = 3.8711085319519043
i = 18000, Discriminator Loss = 0.023033861070871353, Generator Loss = 4.0552520751953125
i = 20000, Discriminator Loss = 0.026469510048627853, Generator Loss = 4.31480073928833
In [30]:
import matplotlib.pyplot as plt
import seaborn as sns

# Assuming 'history' is available from training
# Each element in history is a tuple: (iteration, fake_samples, disc_loss, gen_loss)
for entry in history:
    iteration, fake_samples, disc_contour, disc_loss, gen_loss = entry
    
    # Create a figure for each snapshot
    plt.figure(figsize=(6, 6))
    
    # Use Seaborn's kdeplot to compute and display the 2D kernel density estimate.
    sns.kdeplot(x=fake_samples[:, 0], y=fake_samples[:, 1],
                fill=True, levels=50, cmap="viridis")
    
    # Add labels and a title with iteration and losses information.
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(f"Estimated Density at Iteration {iteration}\n"
              f"Discriminator Loss: {disc_loss:.4f}, Generator Loss: {gen_loss:.4f}")
    plt.tight_layout()
    plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

f-GAN for Celebrity Faces (DCGAN)¶

In [1]:
import os
import pandas as pd

# Specify the folder containing your subset of images and the CSV file path
images_folder = 'subset_images'
csv_path = 'list_attr_celeba.csv'

# Read the CSV file into a DataFrame
df = pd.read_csv(csv_path)

# Get a set of image filenames present in the subset_images folder
existing_images = set(os.listdir(images_folder))

# Filter the DataFrame to only include rows where the image_id exists in the folder
df_subset = df[df['image_id'].isin(existing_images)]
df_subset = df_subset.reset_index(drop=True)

print("Total rows in original CSV:", len(df))
print("Total rows in subset CSV:", len(df_subset))

# Optionally, save the subset to a new CSV file for future use:
df_subset.to_csv("subset_list_attr_celeba.csv", index=False)

# Get attribute names from the CSV (all columns except the first "image_id")
attribute_names = list(df_subset.columns[1:])
print("Attribute Names:")
print(attribute_names)
Total rows in original CSV: 202599
Total rows in subset CSV: 15000
Attribute Names:
['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']
In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split

class CelebASubsetDataset(Dataset):
    def __init__(self, images_dir, csv_file, transform=None):
        """
        Args:
            images_dir (str): Path to the directory containing the images.
            csv_file (str): Path to the CSV file with image attributes.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.images_dir = images_dir
        self.transform = transform
        
        # Read the CSV file into a DataFrame
        self.attr_df = pd.read_csv(csv_file)
        
        # Assuming the first column is 'image_id' and the rest are attributes, we store the attribute names.
        self.image_ids = self.attr_df['image_id'].values
        # Get attribute columns (all columns besides 'image_id')
        self.attributes = self.attr_df.drop(columns=['image_id']).values.astype('float32')
        
    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, idx):
        # Get the image file name and its corresponding attributes
        img_id = self.image_ids[idx]
        attr = self.attributes[idx]
        
        # Construct the full image path
        img_path = os.path.join(self.images_dir, img_id)
        
        # Open the image file and ensure it is in RGB mode.
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        # Convert the attributes to a tensor
        attr_tensor = torch.tensor(attr)
        
        return image, attr_tensor

# Update transform: Resize to 64x64, then ToTensor, and normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Create the full dataset instance.
dataset = CelebASubsetDataset(
    images_dir='subset_images',
    csv_file='subset_list_attr_celeba.csv',
    transform=transform
)

# Define a split proportion for training and validation.
# For example, an 80/20 split:
train_size = int(0.999 * len(dataset))
valid_size = len(dataset) - train_size

# Split the dataset using random_split.
train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

# Create DataLoaders for both splits.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)

# Optional: Print out the sizes for confirmation.
print(f"Total dataset size: {len(dataset)}")
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")
Total dataset size: 15000
Training set size: 14985
Validation set size: 15
In [3]:
# =======================================================
# Define a helper module for reshaping in the Generator
# =======================================================
class View(nn.Module):
    def __init__(self, shape):
        """
        A simple layer to reshape tensors to the given shape.
        """
        super(View, self).__init__()
        self.shape = shape
    def forward(self, input):
        return input.view(*self.shape)
In [4]:
# =======================================================
# Generator: DCGAN-Style for 64x64 Images
# =======================================================
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64):
        """
        Generator mapping a latent vector (dimension nz) to a 3x64x64 image.
        It first projects the latent vector to a tensor of shape (ngf*8, 8, 8),
        then upsamples through three layers:
           8x8 -> 16x16,
           16x16 -> 32x32,
           32x32 -> 64x64.
        The final convolution produces a 3-channel image.
        """
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            # Project latent vector and reshape:
            nn.Linear(nz, ngf * 8 * 8 * 8),
            nn.BatchNorm1d(ngf * 8 * 8 * 8),
            nn.ReLU(True),
            View((-1, ngf * 8, 8, 8)),  # Shape: (ngf*8, 8, 8)
            # Upsample: 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # Upsample: 16x16 -> 32x32
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # Upsample: 32x32 -> 64x64
            nn.ConvTranspose2d(ngf * 2, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()  # Output values in the range [-1,1]
        )
        
    def forward(self, input):
        return self.main(input)
In [5]:
# =======================================================
# Discriminator: DCGAN-Style for 64x64 Images
# =======================================================
class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        """
        The discriminator takes a 3x64x64 image and outputs a scalar probability.
        It uses four convolutional layers:
          - 64x64 -> 32x32,
          - 32x32 -> 16x16,
          - 16x16 -> 8x8,
          - 8x8 -> 4x4.
        A final convolution with kernel size 4 collapses the 4x4 feature map to a single value.
        """
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: 3 x 64 x 64 -> ndf x 32 x 32
            nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16; output channels: ndf*2
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8; output channels: ndf*4
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4; output channels: ndf*8
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # Final layer: 4x4 -> 1, using kernel size 4
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()  # For probability output
        )
        
    def forward(self, input):
        out = self.main(input)
        return out.view(-1)
In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from PIL import Image
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

# =======================================================
# Training Procedure using BCE Loss
# =======================================================
def train_f_gan(num_epochs=50, nz=100, device='cpu'):
    device = torch.device(device)
    
    # Instantiate generator and discriminator.
    G = Generator(nz=nz, ngf=64).to(device)
    D = Discriminator(ndf=64).to(device)
    
    # Use Adam optimizers with typical parameters for DCGAN.
    optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    # Use binary cross entropy loss.
    criterion = nn.BCELoss()
    
    # Create a fixed set of latent vectors (25 images) for visualization.
    fixed_noise = torch.randn(25, nz, device=device)
    
    # Variables to store last batch losses for display.
    last_loss_D = None
    last_loss_G = None
    
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(train_dataloader):
            images = images.to(device)  # [B, 3, 224, 224]
            batch_size = images.size(0)
            
            # Create labels.
            real_labels = torch.ones(batch_size, device=device)
            fake_labels = torch.zeros(batch_size, device=device)
            
            # -------------------------
            # Train Discriminator
            # -------------------------
            optimizerD.zero_grad()
            # Real images forward-pass.
            outputs_real = D(images)
            loss_real = criterion(outputs_real, real_labels)
            
            # Generate fake images.
            noise = torch.randn(batch_size, nz, device=device)
            fake_images = G(noise)
            outputs_fake = D(fake_images.detach())
            loss_fake = criterion(outputs_fake, fake_labels)
            
            loss_D = loss_real + loss_fake
            loss_D.backward()
            optimizerD.step()
            
            # -------------------------
            # Train Generator
            # -------------------------
            optimizerG.zero_grad()
            # Generate fake images; we want these to be classified as real.
            noise = torch.randn(batch_size, nz, device=device)
            fake_images = G(noise)
            outputs = D(fake_images)
            loss_G = criterion(outputs, real_labels)
            loss_G.backward()
            optimizerG.step()
            
            if i % 50 == 0:
                print(f"Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(train_dataloader)}], "
                      f"Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")
            
            last_loss_D = loss_D.item()
            last_loss_G = loss_G.item()
        
        # End of each epoch: generate 25 images from fixed latent vectors and display them.
        with torch.no_grad():
            fake_samples = G(fixed_noise).detach().cpu()
        # Create a grid of 25 images.
        grid = utils.make_grid(fake_samples, nrow=5, normalize=True)
        plt.figure(figsize=(8, 8))
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.title(f"Epoch {epoch} | Loss_D: {last_loss_D:.4f} | Loss_G: {last_loss_G:.4f}")
        plt.show()
        plt.close()
        print(f"End of Epoch {epoch} completed.")
    
    return D, G
In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G = train_f_gan(num_epochs=20, nz=100, device=device)
Epoch [0/20], Batch [0/235], Loss_D: 1.4343, Loss_G: 1.4272
Epoch [0/20], Batch [50/235], Loss_D: 0.0313, Loss_G: 8.7729
Epoch [0/20], Batch [100/235], Loss_D: 0.0303, Loss_G: 14.2098
Epoch [0/20], Batch [150/235], Loss_D: 0.2522, Loss_G: 7.2179
Epoch [0/20], Batch [200/235], Loss_D: 0.4485, Loss_G: 6.5541
No description has been provided for this image
End of Epoch 0 completed.
Epoch [1/20], Batch [0/235], Loss_D: 0.2547, Loss_G: 6.2205
Epoch [1/20], Batch [50/235], Loss_D: 0.1791, Loss_G: 3.9664
Epoch [1/20], Batch [100/235], Loss_D: 0.0452, Loss_G: 5.0688
Epoch [1/20], Batch [150/235], Loss_D: 0.3047, Loss_G: 6.7193
Epoch [1/20], Batch [200/235], Loss_D: 0.0989, Loss_G: 4.0727
No description has been provided for this image
End of Epoch 1 completed.
Epoch [2/20], Batch [0/235], Loss_D: 0.0806, Loss_G: 4.5329
Epoch [2/20], Batch [50/235], Loss_D: 0.1334, Loss_G: 5.1288
Epoch [2/20], Batch [100/235], Loss_D: 0.1641, Loss_G: 3.6295
Epoch [2/20], Batch [150/235], Loss_D: 1.7095, Loss_G: 1.4473
Epoch [2/20], Batch [200/235], Loss_D: 0.1942, Loss_G: 4.6550
No description has been provided for this image
End of Epoch 2 completed.
Epoch [3/20], Batch [0/235], Loss_D: 0.0985, Loss_G: 3.7031
Epoch [3/20], Batch [50/235], Loss_D: 0.1984, Loss_G: 3.6030
Epoch [3/20], Batch [100/235], Loss_D: 0.4367, Loss_G: 10.5271
Epoch [3/20], Batch [150/235], Loss_D: 0.1521, Loss_G: 2.9989
Epoch [3/20], Batch [200/235], Loss_D: 0.1850, Loss_G: 3.4543
No description has been provided for this image
End of Epoch 3 completed.
Epoch [4/20], Batch [0/235], Loss_D: 0.4477, Loss_G: 1.8774
Epoch [4/20], Batch [50/235], Loss_D: 0.2037, Loss_G: 4.1418
Epoch [4/20], Batch [100/235], Loss_D: 0.0637, Loss_G: 3.5348
Epoch [4/20], Batch [150/235], Loss_D: 0.6581, Loss_G: 1.7092
Epoch [4/20], Batch [200/235], Loss_D: 0.1727, Loss_G: 2.8104
No description has been provided for this image
End of Epoch 4 completed.
Epoch [5/20], Batch [0/235], Loss_D: 0.1548, Loss_G: 3.2654
Epoch [5/20], Batch [50/235], Loss_D: 0.7366, Loss_G: 5.2767
Epoch [5/20], Batch [100/235], Loss_D: 0.0485, Loss_G: 4.6805
Epoch [5/20], Batch [150/235], Loss_D: 0.2277, Loss_G: 4.8771
Epoch [5/20], Batch [200/235], Loss_D: 0.0831, Loss_G: 3.9929
No description has been provided for this image
End of Epoch 5 completed.
Epoch [6/20], Batch [0/235], Loss_D: 0.0328, Loss_G: 4.9790
Epoch [6/20], Batch [50/235], Loss_D: 0.1184, Loss_G: 3.6737
Epoch [6/20], Batch [100/235], Loss_D: 0.0574, Loss_G: 4.9615
Epoch [6/20], Batch [150/235], Loss_D: 0.4830, Loss_G: 1.6109
Epoch [6/20], Batch [200/235], Loss_D: 1.1322, Loss_G: 11.0200
No description has been provided for this image
End of Epoch 6 completed.
Epoch [7/20], Batch [0/235], Loss_D: 3.3290, Loss_G: 7.9481
Epoch [7/20], Batch [50/235], Loss_D: 0.0750, Loss_G: 3.3906
Epoch [7/20], Batch [100/235], Loss_D: 0.0546, Loss_G: 4.9041
Epoch [7/20], Batch [150/235], Loss_D: 0.0656, Loss_G: 5.2365
Epoch [7/20], Batch [200/235], Loss_D: 0.2962, Loss_G: 4.2396
No description has been provided for this image
End of Epoch 7 completed.
Epoch [8/20], Batch [0/235], Loss_D: 0.1624, Loss_G: 3.8747
Epoch [8/20], Batch [50/235], Loss_D: 0.1166, Loss_G: 2.9114
Epoch [8/20], Batch [100/235], Loss_D: 6.3864, Loss_G: 7.3692
Epoch [8/20], Batch [150/235], Loss_D: 0.2464, Loss_G: 3.1736
Epoch [8/20], Batch [200/235], Loss_D: 0.1069, Loss_G: 3.8590
No description has been provided for this image
End of Epoch 8 completed.
Epoch [9/20], Batch [0/235], Loss_D: 0.0554, Loss_G: 3.5248
Epoch [9/20], Batch [50/235], Loss_D: 0.1436, Loss_G: 3.4984
Epoch [9/20], Batch [100/235], Loss_D: 0.0329, Loss_G: 6.2885
Epoch [9/20], Batch [150/235], Loss_D: 0.4614, Loss_G: 4.2605
Epoch [9/20], Batch [200/235], Loss_D: 0.4134, Loss_G: 3.6515
No description has been provided for this image
End of Epoch 9 completed.
Epoch [10/20], Batch [0/235], Loss_D: 0.0390, Loss_G: 4.0490
Epoch [10/20], Batch [50/235], Loss_D: 0.0531, Loss_G: 3.9692
Epoch [10/20], Batch [100/235], Loss_D: 0.0499, Loss_G: 5.3644
Epoch [10/20], Batch [150/235], Loss_D: 0.4878, Loss_G: 3.4199
Epoch [10/20], Batch [200/235], Loss_D: 0.8591, Loss_G: 6.4835
No description has been provided for this image
End of Epoch 10 completed.
Epoch [11/20], Batch [0/235], Loss_D: 0.4459, Loss_G: 2.0155
Epoch [11/20], Batch [50/235], Loss_D: 0.8302, Loss_G: 1.0756
Epoch [11/20], Batch [100/235], Loss_D: 0.1194, Loss_G: 3.5714
Epoch [11/20], Batch [150/235], Loss_D: 0.1289, Loss_G: 4.6166
Epoch [11/20], Batch [200/235], Loss_D: 0.0646, Loss_G: 4.2834
No description has been provided for this image
End of Epoch 11 completed.
Epoch [12/20], Batch [0/235], Loss_D: 0.1541, Loss_G: 6.3816
Epoch [12/20], Batch [50/235], Loss_D: 0.1974, Loss_G: 5.4419
Epoch [12/20], Batch [100/235], Loss_D: 0.8444, Loss_G: 1.1366
Epoch [12/20], Batch [150/235], Loss_D: 0.1047, Loss_G: 3.8173
Epoch [12/20], Batch [200/235], Loss_D: 0.0680, Loss_G: 4.3182
No description has been provided for this image
End of Epoch 12 completed.
Epoch [13/20], Batch [0/235], Loss_D: 1.0144, Loss_G: 4.6319
Epoch [13/20], Batch [50/235], Loss_D: 0.0882, Loss_G: 3.3174
Epoch [13/20], Batch [100/235], Loss_D: 0.0852, Loss_G: 4.0259
Epoch [13/20], Batch [150/235], Loss_D: 0.0449, Loss_G: 4.9833
Epoch [13/20], Batch [200/235], Loss_D: 0.0284, Loss_G: 6.1240
No description has been provided for this image
End of Epoch 13 completed.
Epoch [14/20], Batch [0/235], Loss_D: 0.0258, Loss_G: 6.4724
Epoch [14/20], Batch [50/235], Loss_D: 0.0082, Loss_G: 5.8401
Epoch [14/20], Batch [100/235], Loss_D: 0.3763, Loss_G: 3.7295
Epoch [14/20], Batch [150/235], Loss_D: 0.1908, Loss_G: 2.3354
Epoch [14/20], Batch [200/235], Loss_D: 0.3295, Loss_G: 2.6553
No description has been provided for this image
End of Epoch 14 completed.
Epoch [15/20], Batch [0/235], Loss_D: 0.1220, Loss_G: 3.8772
Epoch [15/20], Batch [50/235], Loss_D: 0.1942, Loss_G: 3.5176
Epoch [15/20], Batch [100/235], Loss_D: 0.1182, Loss_G: 3.3420
Epoch [15/20], Batch [150/235], Loss_D: 0.1327, Loss_G: 3.4660
Epoch [15/20], Batch [200/235], Loss_D: 1.4664, Loss_G: 10.9529
No description has been provided for this image
End of Epoch 15 completed.
Epoch [16/20], Batch [0/235], Loss_D: 0.1515, Loss_G: 3.7957
Epoch [16/20], Batch [50/235], Loss_D: 0.2007, Loss_G: 4.0982
Epoch [16/20], Batch [100/235], Loss_D: 0.1016, Loss_G: 4.5341
Epoch [16/20], Batch [150/235], Loss_D: 0.0535, Loss_G: 5.0302
Epoch [16/20], Batch [200/235], Loss_D: 0.0535, Loss_G: 3.7612
No description has been provided for this image
End of Epoch 16 completed.
Epoch [17/20], Batch [0/235], Loss_D: 0.0298, Loss_G: 4.4096
Epoch [17/20], Batch [50/235], Loss_D: 0.0466, Loss_G: 5.5681
Epoch [17/20], Batch [100/235], Loss_D: 0.9867, Loss_G: 1.6225
Epoch [17/20], Batch [150/235], Loss_D: 0.3657, Loss_G: 2.0744
Epoch [17/20], Batch [200/235], Loss_D: 0.2883, Loss_G: 3.9713
No description has been provided for this image
End of Epoch 17 completed.
Epoch [18/20], Batch [0/235], Loss_D: 0.1419, Loss_G: 4.5721
Epoch [18/20], Batch [50/235], Loss_D: 0.1509, Loss_G: 4.5356
Epoch [18/20], Batch [100/235], Loss_D: 0.1508, Loss_G: 3.9809
Epoch [18/20], Batch [150/235], Loss_D: 0.0704, Loss_G: 3.7084
Epoch [18/20], Batch [200/235], Loss_D: 0.3868, Loss_G: 5.3765
No description has been provided for this image
End of Epoch 18 completed.
Epoch [19/20], Batch [0/235], Loss_D: 0.1649, Loss_G: 4.5945
Epoch [19/20], Batch [50/235], Loss_D: 0.0568, Loss_G: 4.1682
Epoch [19/20], Batch [100/235], Loss_D: 0.0500, Loss_G: 4.2615
Epoch [19/20], Batch [150/235], Loss_D: 0.0479, Loss_G: 5.1583
Epoch [19/20], Batch [200/235], Loss_D: 0.0492, Loss_G: 5.5410
No description has been provided for this image
End of Epoch 19 completed.

WGAN-GP¶

In [23]:
# =======================================================
# Helper Module: View (for reshaping tensor in Generator)
# =======================================================
class View(nn.Module):
    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape
    def forward(self, input):
        return input.view(*self.shape)
In [24]:
# =======================================================
# Generator: DCGAN-Style for 64x64 Images
# =======================================================
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64):
        """
        Generator mapping a latent vector (dimension nz) to a 3x64x64 image.
        It first projects the latent vector to a tensor of shape (ngf*8, 8, 8),
        then upsamples through three layers:
           8x8 -> 16x16,
           16x16 -> 32x32,
           32x32 -> 64x64.
        The final convolution produces a 3-channel image.
        """
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            # Project latent vector and reshape:
            nn.Linear(nz, ngf * 8 * 8 * 8),
            nn.BatchNorm1d(ngf * 8 * 8 * 8),
            nn.ReLU(True),
            View((-1, ngf * 8, 8, 8)),  # Shape: (ngf*8, 8, 8)
            # Upsample: 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # Upsample: 16x16 -> 32x32
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # Upsample: 32x32 -> 64x64
            nn.ConvTranspose2d(ngf * 2, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()  # Output values in the range [-1,1]
        )
        
    def forward(self, input):
        return self.main(input)
In [25]:
# =======================================================
# Discriminator: DCGAN-Style for 64x64 Images
# =======================================================
class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        """
        The discriminator takes a 3x64x64 image and outputs a scalar probability.
        It uses four convolutional layers:
          - 64x64 -> 32x32,
          - 32x32 -> 16x16,
          - 16x16 -> 8x8,
          - 8x8 -> 4x4.
        A final convolution with kernel size 4 collapses the 4x4 feature map to a single value.
        """
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: 3 x 64 x 64 -> ndf x 32 x 32
            nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16; output channels: ndf*2
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8; output channels: ndf*4
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4; output channels: ndf*8
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # Final layer: 4x4 -> 1, using kernel size 4
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            # No sigmoid
        )
        
    def forward(self, input):
        out = self.main(input)
        return out.view(-1)
In [32]:
import torch
import torch.autograd as autograd

def compute_gradient_penalty(D, real_samples, fake_samples, device):
    batch_size = real_samples.size(0)
    # Generate a random epsilon in [0,1] for each sample. Shape: [B, 1, 1, 1]
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
    
    # Create interpolated samples:
    interpolates = epsilon * real_samples + (1 - epsilon) * fake_samples
    interpolates.requires_grad_(True)
    
    # Compute discriminator output on the interpolated samples:
    d_interpolates = D(interpolates)
    
    # For each sample, set the gradient output to 1.
    grad_outputs = torch.ones(d_interpolates.size(), device=device)
    
    # Compute gradients of the outputs with respect to the interpolated samples.
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    # Flatten gradients to shape [batch_size, -1] and compute L2 norm for each sample.
    gradients = gradients.view(batch_size, -1)
    grad_norms = torch.sqrt(torch.sum(gradients ** 2, dim=1))
    
    # Compute the gradient penalty as the mean squared deviation of the gradients' norm from 1.
    gradient_penalty = torch.mean((grad_norms - 1) ** 2)
    
    return gradient_penalty
In [33]:
# =======================================================
# Loss Functions for WGAN-GP
# =======================================================
def critic_loss(D, G, real_samples, latents, lambda_gp, device):
    fake_samples = G(latents)
    # Critic scores: higher for real samples.
    real_scores = D(real_samples)
    fake_scores = D(fake_samples)
    loss = fake_scores.mean() - real_scores.mean()
    gp = compute_gradient_penalty(D, real_samples, fake_samples, device)
    loss += lambda_gp * gp
    return loss

def generator_loss(D, G, latents):
    # Generator loss: try to maximize the critic's score on fake images.
    return - D(G(latents)).mean()
In [36]:
def train_wgan_gp(num_epochs=50, nz=100, n_critic=5, lambda_gp=10, device='cpu'):
    device = torch.device(device)
    
    G = Generator(nz=nz, ngf=64).to(device)
    D = Discriminator(ndf=64).to(device)
    
    # Parameters
    LR = 1e-4          # Initial learning rate
    MIN_LR = 1e-6      # Minimum learning rate
    DECAY_FACTOR = 1.00004  # Decay factor per epoch


    # Set up Adam optimizers with beta1=0.5 (as in your code)
    optimizerD = optim.Adam(D.parameters(), lr=LR, betas=(0.5, 0.999))
    optimizerG = optim.Adam(G.parameters(), lr=LR, betas=(0.5, 0.999))

    # Define a lambda function for learning rate decay.
    def lr_lambda(epoch):
        # This returns the multiplicative factor that gets multiplied by the initial lr.
        # It ensures that lr never goes below MIN_LR.
        return max((1 / DECAY_FACTOR) ** epoch, MIN_LR / LR)

    # Set up the learning rate schedulers for both optimizers.
    schedulerD = optim.lr_scheduler.LambdaLR(optimizerD, lr_lambda=lr_lambda)
    schedulerG = optim.lr_scheduler.LambdaLR(optimizerG, lr_lambda=lr_lambda)
    
    fixed_noise = torch.randn(25, nz, device=device)
    
    last_loss_D = None
    last_loss_G = None
    
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(train_dataloader):
            images = images.to(device)
            batch_size = images.size(0)
            
            # Update critic n_critic times.
            for _ in range(n_critic):
                noise = torch.randn(batch_size, nz, device=device)
                optimizerD.zero_grad()
                loss_D = critic_loss(D, G, images, noise, lambda_gp, device)
                loss_D.backward()
                optimizerD.step()
            
            # Update generator once.
            noise = torch.randn(batch_size, nz, device=device)
            optimizerG.zero_grad()
            loss_G = generator_loss(D, G, noise)
            loss_G.backward()
            optimizerG.step()
            
            if i % 50 == 0:
                print(f"Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(train_dataloader)}], "
                      f"Critic Loss: {loss_D.item():.4f}, Generator Loss: {loss_G.item():.4f}")
            
            last_loss_D = loss_D.item()
            last_loss_G = loss_G.item()
        
        # End of epoch: generate and display 25 images from fixed noise.
        with torch.no_grad():
            fake_samples = G(fixed_noise).detach().cpu()
        grid = utils.make_grid(fake_samples, nrow=5, normalize=True)
        plt.figure(figsize=(8, 8))
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.title(f"Epoch {epoch} | Critic Loss: {last_loss_D:.4f} | Generator Loss: {last_loss_G:.4f}")
        plt.show()
        plt.close()
        print(f"End of Epoch {epoch} completed.")

        # At the end of the epoch:
        schedulerD.step()
        schedulerG.step()
    
    return D, G
In [39]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G = train_wgan_gp(num_epochs=25, nz=100, n_critic=5, lambda_gp=10, device=device)
Epoch [0/25], Batch [0/235], Critic Loss: -7.1815, Generator Loss: 3.3995
Epoch [0/25], Batch [50/235], Critic Loss: -319.1111, Generator Loss: 161.0143
Epoch [0/25], Batch [100/235], Critic Loss: -575.8983, Generator Loss: 293.8505
Epoch [0/25], Batch [150/235], Critic Loss: -947.6006, Generator Loss: 470.6101
Epoch [0/25], Batch [200/235], Critic Loss: -1310.2970, Generator Loss: 664.8742
No description has been provided for this image
End of Epoch 0 completed.
Epoch [1/25], Batch [0/235], Critic Loss: -1603.1932, Generator Loss: 804.1373
Epoch [1/25], Batch [50/235], Critic Loss: -1037.7480, Generator Loss: 103.4770
Epoch [1/25], Batch [100/235], Critic Loss: -1613.4746, Generator Loss: 573.4979
Epoch [1/25], Batch [150/235], Critic Loss: -2640.9807, Generator Loss: 1264.7196
Epoch [1/25], Batch [200/235], Critic Loss: -3086.8433, Generator Loss: 1482.9570
No description has been provided for this image
End of Epoch 1 completed.
Epoch [2/25], Batch [0/235], Critic Loss: 229.4240, Generator Loss: 1392.4490
Epoch [2/25], Batch [50/235], Critic Loss: -28.0986, Generator Loss: 1401.1266
Epoch [2/25], Batch [100/235], Critic Loss: -3754.0598, Generator Loss: 1808.0903
Epoch [2/25], Batch [150/235], Critic Loss: -4414.3828, Generator Loss: 2138.6362
Epoch [2/25], Batch [200/235], Critic Loss: 29.6658, Generator Loss: 1225.9301
No description has been provided for this image
End of Epoch 2 completed.
Epoch [3/25], Batch [0/235], Critic Loss: 17.0579, Generator Loss: 1219.2595
Epoch [3/25], Batch [50/235], Critic Loss: 5.0561, Generator Loss: 1213.5356
Epoch [3/25], Batch [100/235], Critic Loss: -0.4950, Generator Loss: 1210.5740
Epoch [3/25], Batch [150/235], Critic Loss: -2.3364, Generator Loss: 1205.5253
Epoch [3/25], Batch [200/235], Critic Loss: -5.4682, Generator Loss: 1204.8875
No description has been provided for this image
End of Epoch 3 completed.
Epoch [4/25], Batch [0/235], Critic Loss: -1.7772, Generator Loss: 1206.0137
Epoch [4/25], Batch [50/235], Critic Loss: -3.2213, Generator Loss: 1200.0872
Epoch [4/25], Batch [100/235], Critic Loss: -7.3166, Generator Loss: 1199.0438
Epoch [4/25], Batch [150/235], Critic Loss: -6.2481, Generator Loss: 1196.1101
Epoch [4/25], Batch [200/235], Critic Loss: -6.9179, Generator Loss: 1193.1678
No description has been provided for this image
End of Epoch 4 completed.
Epoch [5/25], Batch [0/235], Critic Loss: -5.9082, Generator Loss: 1195.7135
Epoch [5/25], Batch [50/235], Critic Loss: -7.8350, Generator Loss: 1186.9421
Epoch [5/25], Batch [100/235], Critic Loss: -9.1983, Generator Loss: 1174.1233
Epoch [5/25], Batch [150/235], Critic Loss: -8.1507, Generator Loss: 1183.1907
Epoch [5/25], Batch [200/235], Critic Loss: -7.0861, Generator Loss: 1164.1135
No description has been provided for this image
End of Epoch 5 completed.
Epoch [6/25], Batch [0/235], Critic Loss: -6.0164, Generator Loss: 1167.7697
Epoch [6/25], Batch [50/235], Critic Loss: -8.0003, Generator Loss: 1160.3273
Epoch [6/25], Batch [100/235], Critic Loss: -8.2061, Generator Loss: 1158.1433
Epoch [6/25], Batch [150/235], Critic Loss: -8.6398, Generator Loss: 1143.9407
Epoch [6/25], Batch [200/235], Critic Loss: -12.2942, Generator Loss: 1142.6199
No description has been provided for this image
End of Epoch 6 completed.
Epoch [7/25], Batch [0/235], Critic Loss: -9.0067, Generator Loss: 1134.8313
Epoch [7/25], Batch [50/235], Critic Loss: -11.3545, Generator Loss: 1120.7852
Epoch [7/25], Batch [100/235], Critic Loss: -10.2414, Generator Loss: 1113.6353
Epoch [7/25], Batch [150/235], Critic Loss: -13.3492, Generator Loss: 1095.5884
Epoch [7/25], Batch [200/235], Critic Loss: -14.7403, Generator Loss: 1083.0547
No description has been provided for this image
End of Epoch 7 completed.
Epoch [8/25], Batch [0/235], Critic Loss: -12.3538, Generator Loss: 1081.8322
Epoch [8/25], Batch [50/235], Critic Loss: -11.8157, Generator Loss: 1062.2549
Epoch [8/25], Batch [100/235], Critic Loss: -13.7031, Generator Loss: 1061.2454
Epoch [8/25], Batch [150/235], Critic Loss: -15.6746, Generator Loss: 1038.0032
Epoch [8/25], Batch [200/235], Critic Loss: -12.2785, Generator Loss: 1036.6870
No description has been provided for this image
End of Epoch 8 completed.
Epoch [9/25], Batch [0/235], Critic Loss: -12.3723, Generator Loss: 1021.6680
Epoch [9/25], Batch [50/235], Critic Loss: -9.6453, Generator Loss: 1016.1304
Epoch [9/25], Batch [100/235], Critic Loss: -13.5530, Generator Loss: 1014.8564
Epoch [9/25], Batch [150/235], Critic Loss: -15.0983, Generator Loss: 991.1080
Epoch [9/25], Batch [200/235], Critic Loss: -13.8649, Generator Loss: 984.4841
No description has been provided for this image
End of Epoch 9 completed.
Epoch [10/25], Batch [0/235], Critic Loss: -9.6877, Generator Loss: 986.0001
Epoch [10/25], Batch [50/235], Critic Loss: -11.9478, Generator Loss: 973.5349
Epoch [10/25], Batch [100/235], Critic Loss: -12.6790, Generator Loss: 959.2227
Epoch [10/25], Batch [150/235], Critic Loss: -14.4752, Generator Loss: 952.8516
Epoch [10/25], Batch [200/235], Critic Loss: -12.7882, Generator Loss: 948.1614
No description has been provided for this image
End of Epoch 10 completed.
Epoch [11/25], Batch [0/235], Critic Loss: -10.8146, Generator Loss: 943.7811
Epoch [11/25], Batch [50/235], Critic Loss: -14.7517, Generator Loss: 940.8361
Epoch [11/25], Batch [100/235], Critic Loss: -16.9774, Generator Loss: 943.5679
Epoch [11/25], Batch [150/235], Critic Loss: -16.0654, Generator Loss: 939.0920
Epoch [11/25], Batch [200/235], Critic Loss: -20.3614, Generator Loss: 936.4711
No description has been provided for this image
End of Epoch 11 completed.
Epoch [12/25], Batch [0/235], Critic Loss: -18.1832, Generator Loss: 928.4172
Epoch [12/25], Batch [50/235], Critic Loss: -15.1205, Generator Loss: 928.4465
Epoch [12/25], Batch [100/235], Critic Loss: -15.1267, Generator Loss: 939.8590
Epoch [12/25], Batch [150/235], Critic Loss: -17.2297, Generator Loss: 934.1822
Epoch [12/25], Batch [200/235], Critic Loss: -19.3050, Generator Loss: 939.5249
No description has been provided for this image
End of Epoch 12 completed.
Epoch [13/25], Batch [0/235], Critic Loss: -12.7700, Generator Loss: 928.9781
Epoch [13/25], Batch [50/235], Critic Loss: -15.1378, Generator Loss: 934.3294
Epoch [13/25], Batch [100/235], Critic Loss: -14.2550, Generator Loss: 935.8308
Epoch [13/25], Batch [150/235], Critic Loss: -21.9916, Generator Loss: 936.1085
Epoch [13/25], Batch [200/235], Critic Loss: -17.9902, Generator Loss: 944.5997
No description has been provided for this image
End of Epoch 13 completed.
Epoch [14/25], Batch [0/235], Critic Loss: -14.3817, Generator Loss: 929.7521
Epoch [14/25], Batch [50/235], Critic Loss: -18.0552, Generator Loss: 944.4028
Epoch [14/25], Batch [100/235], Critic Loss: -11.2363, Generator Loss: 935.4798
Epoch [14/25], Batch [150/235], Critic Loss: -20.1454, Generator Loss: 942.7131
Epoch [14/25], Batch [200/235], Critic Loss: -22.4876, Generator Loss: 942.3488
No description has been provided for this image
End of Epoch 14 completed.
Epoch [15/25], Batch [0/235], Critic Loss: -16.9694, Generator Loss: 936.5730
Epoch [15/25], Batch [50/235], Critic Loss: -23.6465, Generator Loss: 943.0557
Epoch [15/25], Batch [100/235], Critic Loss: -13.8048, Generator Loss: 942.3694
Epoch [15/25], Batch [150/235], Critic Loss: -17.4215, Generator Loss: 935.6983
Epoch [15/25], Batch [200/235], Critic Loss: -18.4353, Generator Loss: 937.7614
No description has been provided for this image
End of Epoch 15 completed.
Epoch [16/25], Batch [0/235], Critic Loss: -8.8008, Generator Loss: 936.5846
Epoch [16/25], Batch [50/235], Critic Loss: -21.4678, Generator Loss: 955.0105
Epoch [16/25], Batch [100/235], Critic Loss: -12.5794, Generator Loss: 943.1880
Epoch [16/25], Batch [150/235], Critic Loss: -23.7656, Generator Loss: 939.2748
Epoch [16/25], Batch [200/235], Critic Loss: -26.7899, Generator Loss: 936.5636
No description has been provided for this image
End of Epoch 16 completed.
Epoch [17/25], Batch [0/235], Critic Loss: -11.8148, Generator Loss: 929.4254
Epoch [17/25], Batch [50/235], Critic Loss: -18.5867, Generator Loss: 941.4560
Epoch [17/25], Batch [100/235], Critic Loss: -11.9151, Generator Loss: 938.3667
Epoch [17/25], Batch [150/235], Critic Loss: -17.8414, Generator Loss: 943.4852
Epoch [17/25], Batch [200/235], Critic Loss: -14.1036, Generator Loss: 949.7693
No description has been provided for this image
End of Epoch 17 completed.
Epoch [18/25], Batch [0/235], Critic Loss: -28.9961, Generator Loss: 940.0487
Epoch [18/25], Batch [50/235], Critic Loss: -21.4629, Generator Loss: 939.6287
Epoch [18/25], Batch [100/235], Critic Loss: -21.2551, Generator Loss: 948.8262
Epoch [18/25], Batch [150/235], Critic Loss: -14.8244, Generator Loss: 939.2914
Epoch [18/25], Batch [200/235], Critic Loss: -16.4426, Generator Loss: 947.8840
No description has been provided for this image
End of Epoch 18 completed.
Epoch [19/25], Batch [0/235], Critic Loss: -18.4954, Generator Loss: 946.0103
Epoch [19/25], Batch [50/235], Critic Loss: -13.0874, Generator Loss: 941.6476
Epoch [19/25], Batch [100/235], Critic Loss: -8.4792, Generator Loss: 936.5409
Epoch [19/25], Batch [150/235], Critic Loss: -8.4111, Generator Loss: 948.5744
Epoch [19/25], Batch [200/235], Critic Loss: -8.8216, Generator Loss: 949.1342
No description has been provided for this image
End of Epoch 19 completed.
Epoch [20/25], Batch [0/235], Critic Loss: -9.7134, Generator Loss: 945.5421
Epoch [20/25], Batch [50/235], Critic Loss: -6.9775, Generator Loss: 948.1437
Epoch [20/25], Batch [100/235], Critic Loss: -16.3411, Generator Loss: 947.7766
Epoch [20/25], Batch [150/235], Critic Loss: -29.5035, Generator Loss: 953.3276
Epoch [20/25], Batch [200/235], Critic Loss: -12.2876, Generator Loss: 940.1945
No description has been provided for this image
End of Epoch 20 completed.
Epoch [21/25], Batch [0/235], Critic Loss: -14.1305, Generator Loss: 933.7478
Epoch [21/25], Batch [50/235], Critic Loss: -17.7026, Generator Loss: 942.4226
Epoch [21/25], Batch [100/235], Critic Loss: -18.2549, Generator Loss: 942.1077
Epoch [21/25], Batch [150/235], Critic Loss: -20.6455, Generator Loss: 939.1882
Epoch [21/25], Batch [200/235], Critic Loss: -12.8188, Generator Loss: 952.3013
No description has been provided for this image
End of Epoch 21 completed.
Epoch [22/25], Batch [0/235], Critic Loss: -12.8006, Generator Loss: 944.2827
Epoch [22/25], Batch [50/235], Critic Loss: -15.2925, Generator Loss: 948.0229
Epoch [22/25], Batch [100/235], Critic Loss: -18.3503, Generator Loss: 947.7360
Epoch [22/25], Batch [150/235], Critic Loss: -13.2577, Generator Loss: 941.4637
Epoch [22/25], Batch [200/235], Critic Loss: -12.0882, Generator Loss: 953.0192
No description has been provided for this image
End of Epoch 22 completed.
Epoch [23/25], Batch [0/235], Critic Loss: -1.3852, Generator Loss: 937.0220
Epoch [23/25], Batch [50/235], Critic Loss: -8.8576, Generator Loss: 954.6143
Epoch [23/25], Batch [100/235], Critic Loss: -13.4211, Generator Loss: 946.9574
Epoch [23/25], Batch [150/235], Critic Loss: -12.9877, Generator Loss: 948.9841
Epoch [23/25], Batch [200/235], Critic Loss: -17.1787, Generator Loss: 944.7324
No description has been provided for this image
End of Epoch 23 completed.
Epoch [24/25], Batch [0/235], Critic Loss: -13.0251, Generator Loss: 942.5323
Epoch [24/25], Batch [50/235], Critic Loss: -28.6021, Generator Loss: 955.1877
Epoch [24/25], Batch [100/235], Critic Loss: -19.4927, Generator Loss: 946.7385
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[39], line 2
      1 device = 'cuda' if torch.cuda.is_available() else 'cpu'
----> 2 D, G = train_wgan_gp(num_epochs=25, nz=100, n_critic=5, lambda_gp=10, device=device)

Cell In[36], line 41, in train_wgan_gp(num_epochs, nz, n_critic, lambda_gp, device)
     39 noise = torch.randn(batch_size, nz, device=device)
     40 optimizerD.zero_grad()
---> 41 loss_D = critic_loss(D, G, images, noise, lambda_gp, device)
     42 loss_D.backward()
     43 optimizerD.step()

Cell In[33], line 10, in critic_loss(D, G, real_samples, latents, lambda_gp, device)
      8 fake_scores = D(fake_samples)
      9 loss = fake_scores.mean() - real_scores.mean()
---> 10 gp = compute_gradient_penalty(D, real_samples, fake_samples, device)
     11 loss += lambda_gp * gp
     12 return loss

Cell In[32], line 14, in compute_gradient_penalty(D, real_samples, fake_samples, device)
     11 interpolates.requires_grad_(True)
     13 # Compute discriminator output on the interpolated samples:
---> 14 d_interpolates = D(interpolates)
     16 # For each sample, set the gradient output to 1.
     17 grad_outputs = torch.ones(d_interpolates.size(), device=device)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

Cell In[25], line 38, in Discriminator.forward(self, input)
     37 def forward(self, input):
---> 38     out = self.main(input)
     39     return out.view(-1)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input)
    248 def forward(self, input):
    249     for module in self:
--> 250         input = module(input)
    251     return input

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:554, in Conv2d.forward(self, input)
    553 def forward(self, input: Tensor) -> Tensor:
--> 554     return self._conv_forward(input, self.weight, self.bias)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:549, in Conv2d._conv_forward(self, input, weight, bias)
    537 if self.padding_mode != "zeros":
    538     return F.conv2d(
    539         F.pad(
    540             input, self._reversed_padding_repeated_twice, mode=self.padding_mode
   (...)
    547         self.groups,
    548     )
--> 549 return F.conv2d(
    550     input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    551 )

KeyboardInterrupt: